package org.zjvis.datascience.service.dag;

import com.alibaba.fastjson.JSONObject;
import org.apache.commons.math3.util.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.zjvis.datascience.common.dto.TaskInstanceDTO;
import org.zjvis.datascience.common.dto.graph.GraphDTO;
import org.zjvis.datascience.common.dto.graph.GraphInstanceDTO;
import org.zjvis.datascience.common.graph.model.GraphAttr;
import org.zjvis.datascience.common.util.ToolUtil;
import org.zjvis.datascience.common.vo.graph.*;
import org.zjvis.datascience.service.graph.GraphInstanceService;
import org.zjvis.datascience.service.graph.GraphService;
import org.zjvis.datascience.service.graph.JanusGraphEmbedService;

import java.util.*;
import java.util.concurrent.Callable;
import java.util.stream.Collectors;

/**
 * @description Graph task runner: builds node/link instances for a graph from its configured
 * categories (node types) and edges (relation types), stores them in the graph's dataJson and
 * then writes them to JanusGraph on a background thread.
 *
 * <p>An instance holds mutable per-run build state (id counters, node/link maps) without any
 * synchronization; presumably one instance serves exactly one run.
 * @date 2021-08-20
 */
public class GraphRunner implements Callable<TaskRunnerResult> {
    private static final Logger logger = LoggerFactory.getLogger(GraphRunner.class);

    private final GraphService graphService;
    private final GraphInstanceService graphInstanceService;
    private final JanusGraphEmbedService janusGraphEmbedService;
    private final TaskInstanceDTO instance;

    /** Categories (node types) parsed from the task's inputInfo. */
    private List<CategoryVO> categories;
    /** Edges (relation types) parsed from the task's inputInfo. */
    private List<EdgeVO> edges;

    /** category id -> category config. */
    private Map<String, CategoryVO> cIdMap;
    /** edge id -> edge config. */
    private Map<String, EdgeVO> eIdMap;
    /** category id -> (node id -> node instance). */
    private Map<String, Map<String, NodeVO>> cId2nIdMap;
    /** category id -> (key attribute value -> node id); constant-time lookup of created nodes. */
    private Map<String, Map<Object, String>> cId2keyNidMap;
    /** edge id -> (link id -> link instance). */
    private Map<String, Map<String, LinkVO>> eId2lIdMap;
    /** Last issued node / link sequence numbers; -1 means none issued yet. */
    private long nIdCnt = -1L;
    private long lIdCnt = -1L;
    /** Ids of categories that have already taken part in at least one edge build. */
    private Set<String> cIdHelper;

    private Long graphId;
    private GraphOutputVO outputHelper;

    /** category id -> (key attribute value -> attribute data), loaded via graphService. */
    private Map<String, Map<Object, GraphAttr>> attrDataHelper;
    /** category/edge id -> (attribute name -> whether some instance holds more than one value). */
    private Map<String, Map<String, Boolean>> multiValueHelper;
    /** attribute id -> (owning category id, attribute config). */
    private Map<String, Pair<String, CategoryAttrVO>> attrIdHelper;

    // NOTE(review): table/column identifiers cannot be bound as JDBC parameters, so they are
    // interpolated with String.format. Column names come from category/edge configuration
    // (e.g. joinConfigure header names) — confirm they are validated or quoted upstream,
    // otherwise these templates are an SQL injection vector.
    private static final String QUERY_COLUMNS_FROM_ONE_TABLE_SQL = "SELECT DISTINCT %s FROM %s ORDER BY %s";

    /** Distinct (left key attr, right key attr) pairs of rows whose join columns match. */
    private static final String QUERY_COLUMNS_FROM_JOIN_TABLE_SQL_OPT =
            "SELECT DISTINCT a.attr, b.attr " +
            "FROM " +
            "(SELECT DISTINCT %s AS joinkey, %s AS attr " +
            "FROM %s) a " +
            "INNER JOIN " +
            "(SELECT DISTINCT %s AS joinkey ,%s AS attr " +
            "FROM %s) b " +
            "ON a.joinkey = b.joinkey " +
            "ORDER BY a.attr, b.attr ";

    /**
     * Parses the task configuration and eagerly loads the attribute data of every category.
     *
     * @param graphService           graph persistence/query service
     * @param graphInstanceService   service tracking graph instance status
     * @param janusGraphEmbedService service writing the built graph into JanusGraph
     * @param instance               task instance whose dataJson carries the inputInfo
     */
    public GraphRunner(GraphService graphService, GraphInstanceService graphInstanceService, JanusGraphEmbedService janusGraphEmbedService, TaskInstanceDTO instance) {
        this.graphService = graphService;
        this.graphInstanceService = graphInstanceService;
        this.instance = instance;
        this.janusGraphEmbedService = janusGraphEmbedService;
        this.initHelper();
        long t1 = System.currentTimeMillis();
        this.initAttrDataHelper();
        logger.info("initAttrDataHelper in {} ms ", System.currentTimeMillis() - t1);
    }

    /**
     * Builds the graph data.
     *
     * @return result with code 0 and the JSON output log on success, or code 500 and the
     *         exception message on failure (the exception is also logged with its stack trace)
     */
    @Override
    public TaskRunnerResult call() throws Exception {
        try {
            long t1 = System.currentTimeMillis();
            String resultString = loadGraphData();
            logger.info("loadGraphData in {} ms ", System.currentTimeMillis() - t1);
            return new TaskRunnerResult(0, resultString);
        } catch (Exception e) {
            logger.error("loadGraphData failed, graphId = {}", graphId, e);
            return new TaskRunnerResult(500, e.getMessage());
        }
    }

    /**
     * Parses {@code inputInfo} from the task's dataJson and (re)initializes all lookup maps.
     * Edge attributes referenced by {@code attrIds} are resolved and appended to the edge's attrs.
     */
    public void initHelper() {
        cIdMap = new HashMap<>();
        eIdMap = new HashMap<>();
        cId2nIdMap = new HashMap<>();
        cId2keyNidMap = new HashMap<>();
        eId2lIdMap = new HashMap<>();
        cIdHelper = new HashSet<>();
        multiValueHelper = new HashMap<>();
        attrIdHelper = new HashMap<>();
        JSONObject jsonObject = JSONObject.parseObject(instance.getDataJson());
        JSONObject inputInfo = jsonObject.getJSONObject("inputInfo");
        categories = inputInfo.getJSONArray("categories").toJavaList(CategoryVO.class);
        edges = inputInfo.getJSONArray("edges").toJavaList(EdgeVO.class);
        graphId = inputInfo.getLong("graphId");
        categories.forEach(category -> {
            String cId = category.getId();
            cIdMap.put(cId, category);
            cId2nIdMap.put(cId, new HashMap<>());
            cId2keyNidMap.put(cId, new HashMap<>());
            Map<String, Boolean> multiValue = new HashMap<>();
            multiValueHelper.put(cId, multiValue);
            attrIdHelper.put(category.getKeyAttr().getId(), new Pair<>(cId, category.getKeyAttr()));
            for (CategoryAttrVO categoryAttr : category.getAttrs()) {
                multiValue.put(categoryAttr.getName(), false);
                attrIdHelper.put(categoryAttr.getId(), new Pair<>(cId, categoryAttr));
            }
        });
        edges.forEach(edge -> {
            eIdMap.put(edge.getId(), edge);
            eId2lIdMap.put(edge.getId(), new HashMap<>());
            Map<String, Boolean> multiValue = new HashMap<>();
            multiValueHelper.put(edge.getId(), multiValue);
            if (edge.getAttrIds() != null) {
                for (String attrId : edge.getAttrIds()) {
                    CategoryAttrVO categoryAttr = attrIdHelper.get(attrId).getValue();
                    multiValue.put(categoryAttr.getName(), false);
                    if (edge.getAttrs() == null) {
                        edge.setAttrs(new ArrayList<>());
                    }
                    edge.getAttrs().add(categoryAttr);
                }
            }
        });

        outputHelper = new GraphOutputVO();
    }

    /** Loads, for every category, the attribute data keyed by key-attribute value. */
    public void initAttrDataHelper() {
        attrDataHelper = new HashMap<>();
        for (CategoryVO category : categories) {
            attrDataHelper.put(category.getId(), graphService.queryAttrDataHelper(category));
        }
    }

    /**
     * Builds all node/link instances, stores them in the graph's dataJson and starts the
     * asynchronous JanusGraph write.
     *
     * <ol>
     *   <li>Edges whose endpoints share a table are resolved by a single-table query.</li>
     *   <li>Edges across tables are resolved by joining on the edge's joinConfigure headers.</li>
     *   <li>Categories that no edge touched get one node per distinct key value.</li>
     * </ol>
     *
     * @return JSON log containing {@code graphId} and {@code janusGraphName}
     */
    public String loadGraphData() {
        JSONObject outputLog = new JSONObject();
        GraphInstanceDTO graphInstance = new GraphInstanceDTO(graphService.queryById(graphId));
        Long graphInstanceId = graphInstanceService.save(graphInstance);

        buildAllEdges();
        buildAllIsolatedCategories();

        GraphDTO graphDTO = graphService.queryById(graphId);
        JSONObject graphData = JSONObject.parseObject(graphDTO.getDataJson());
        String graphName = graphData.getString("janusGraphName");

        outputHelper.setId(graphId);
        outputHelper.setJanusGraphName(graphName);
        // Flatten the per-category / per-edge maps into plain lists.
        outputHelper.setNodes(cId2nIdMap.values().stream().flatMap(subMap -> subMap.values().stream()).collect(Collectors.toList()));
        outputHelper.setLinks(eId2lIdMap.values().stream().flatMap(subMap -> subMap.values().stream()).collect(Collectors.toList()));

        logger.info("nodes: {}  link:  {}", nIdCnt + 1, lIdCnt + 1);

        graphData.put("nodes", outputHelper.getNodes());
        graphData.put("links", outputHelper.getLinks());
        graphData.put("multiValueHelper", multiValueHelper);
        graphDTO.setDataJson(graphData.toJSONString());
        graphService.update(graphDTO);
        outputLog.put("graphId", graphId);
        outputLog.put("janusGraphName", graphName);
        startJanusWrite(graphInstanceId, graphName, graphDTO, graphData);
        return outputLog.toJSONString();
    }

    /** Builds node/link instances for every configured edge, choosing the query by edge shape. */
    private void buildAllEdges() {
        for (Map.Entry<String, EdgeVO> entry : eIdMap.entrySet()) {
            String eId = entry.getKey();
            EdgeVO edge = entry.getValue();
            String srcCid = edge.getSource();
            String tarCid = edge.getTarget();
            if (srcCid.equals(tarCid)) {
                // Self edge: both endpoints belong to the same category.
                if (!buildPairFromSelfEdge(srcCid, eId)) {
                    logger.error("buildPairFromSelfEdge error, param = {}", String.join(", ", srcCid, tarCid, eId));
                }
            } else if (cIdMap.get(srcCid).tableName().equals(cIdMap.get(tarCid).tableName())) {
                if (!buildPairFromOneTable(srcCid, tarCid, eId)) {
                    logger.error("buildPairFromOneTable error, param = {}", String.join(", ", srcCid, tarCid, eId));
                }
            } else {
                if (!buildPairFromJoinTable(srcCid, tarCid, eId)) {
                    logger.error("buildPairFromJoinTable error, param = {}", String.join(", ", srcCid, tarCid, eId));
                }
            }
        }
    }

    /** Creates nodes for every category that did not take part in any edge build. */
    private void buildAllIsolatedCategories() {
        // Copy the key set: keySet() is a live view, and removeAll on it would delete the
        // connected categories' entries from attrDataHelper itself.
        Set<String> isolatedCids = new HashSet<>(attrDataHelper.keySet());
        isolatedCids.removeAll(cIdHelper);
        for (String cid : isolatedCids) {
            buildIsolatedCategory(cid);
        }
    }

    /**
     * Starts a background thread that writes the built graph to JanusGraph, records the
     * instance status, and stores the returned id map into the graph's dataJson.
     */
    private void startJanusWrite(Long graphInstanceId, String graphName, GraphDTO graphDTO, JSONObject graphData) {
        Thread thread = new Thread(() -> {
            try {
                graphInstanceService.setStatusRunning(graphInstanceId);
                long t1 = System.currentTimeMillis();
                Map<String, String> idMap = write2Janus(graphId);
                Long duringTime = System.currentTimeMillis() - t1;
                graphInstanceService.setStatusSuccess(graphInstanceId, duringTime, graphName);
                graphData.put("idMap", idMap);
                graphDTO.setDataJson(graphData.toJSONString());
                graphService.update(graphDTO);
            } catch (Exception e) {
                graphInstanceService.setStatusFail(graphInstanceId, graphName, e.getMessage());
                logger.error("write2Janus error: {}", e.getMessage(), e);
            }
        }, "graph-janus-write-" + graphId);
        thread.start();
    }

    /**
     * Builds links for an edge whose two categories live in the same table: every distinct
     * (source key, target key) row becomes one link.
     *
     * @return false if the categories are not in the same table or a node could not be resolved
     */
    public boolean buildPairFromOneTable(String srcCid, String tarCid, String eId) {
        CategoryVO srcCategory = cIdMap.get(srcCid);
        CategoryVO tarCategory = cIdMap.get(tarCid);
        if (!srcCategory.tableName().equals(tarCategory.tableName())) {
            logger.error("tableName not equal, table = {}", srcCategory.tableName() + ", " + tarCategory.tableName());
            return false;
        }
        String queryColumnName = String.join(",",
                srcCategory.getKeyAttr().getColumn(),
                tarCategory.getKeyAttr().getColumn());
        String tableName = ToolUtil.alignTableName(srcCategory.tableName(), 0L);
        String sql = String.format(QUERY_COLUMNS_FROM_ONE_TABLE_SQL, queryColumnName, tableName, queryColumnName);
        List<List<Object>> queryResult = graphService.queryColumnsFromTable(sql);
        return buildPairFromQueryResult(srcCid, tarCid, eId, queryResult);
    }

    /**
     * Builds links for an edge spanning two tables by joining them on the edge's joinConfigure
     * {@code leftHeaderName} / {@code rightHeaderName} columns.
     *
     * @return false if joinConfigure is missing or incomplete, or a node could not be resolved
     */
    public boolean buildPairFromJoinTable(String srcCid, String tarCid, String eId) {
        CategoryVO srcCategory = cIdMap.get(srcCid);
        CategoryVO tarCategory = cIdMap.get(tarCid);
        EdgeVO edge = eIdMap.get(eId);
        String leftTableName = ToolUtil.alignTableName(srcCategory.tableName(), 0L);
        String rightTableName = ToolUtil.alignTableName(tarCategory.tableName(), 0L);
        JSONObject conf = edge.getJoinConfigure();
        if (conf == null) {
            logger.error("edge joinConfigure error, conf=null");
            return false;
        }
        String leftHeaderName = conf.getString("leftHeaderName");
        String rightHeaderName = conf.getString("rightHeaderName");
        if (leftHeaderName == null || rightHeaderName == null) {
            logger.error("edge joinConfigure error, leftHeaderName = {},  rightHeaderName = {}", leftHeaderName, rightHeaderName);
            return false;
        }
        String sql = String.format(QUERY_COLUMNS_FROM_JOIN_TABLE_SQL_OPT,
                leftHeaderName, srcCategory.getKeyAttr().getColumn(), leftTableName,
                rightHeaderName, tarCategory.getKeyAttr().getColumn(), rightTableName);
        long t1 = System.currentTimeMillis();
        List<List<Object>> queryResult = graphService.queryColumnsFromTable(sql);
        logger.info("queryColumnsFromTable: {} ms  row: {}", System.currentTimeMillis() - t1, queryResult.size());
        return buildPairFromQueryResult(srcCid, tarCid, eId, queryResult);
    }

    /**
     * Builds links for a self edge (both endpoints in category {@code cid}) by self-joining the
     * category's table on the edge's joinConfigure header columns.
     *
     * @return false if joinConfigure is missing or incomplete, or a node could not be resolved
     */
    public boolean buildPairFromSelfEdge(String cid, String eId) {
        CategoryVO category = cIdMap.get(cid);
        EdgeVO edge = eIdMap.get(eId);
        String tableName = ToolUtil.alignTableName(category.tableName(), 0L);
        JSONObject conf = edge.getJoinConfigure();
        if (conf == null) {
            logger.error("edge joinConfigure error, conf=null");
            return false;
        }
        String leftHeaderName = conf.getString("leftHeaderName");
        String rightHeaderName = conf.getString("rightHeaderName");
        if (leftHeaderName == null || rightHeaderName == null) {
            logger.error("edge joinConfigure error, leftHeaderName = {},  rightHeaderName = {}", leftHeaderName, rightHeaderName);
            return false;
        }
        String sql = String.format(QUERY_COLUMNS_FROM_JOIN_TABLE_SQL_OPT,
                leftHeaderName, category.getKeyAttr().getColumn(), tableName,
                rightHeaderName, category.getKeyAttr().getColumn(), tableName);
        List<List<Object>> queryResult = graphService.queryColumnsFromTable(sql);
        return buildSelfEdgeFromQueryResult(cid, eId, queryResult);
    }

    /**
     * Turns each (source key, target key) row into one link between the corresponding nodes,
     * creating nodes as needed. Rows are assumed distinct (the queries use SELECT DISTINCT), so
     * no link de-duplication is done here.
     *
     * @return false if a node id could not be resolved
     */
    public boolean buildPairFromQueryResult(String srcCid, String tarCid, String eId, List<List<Object>> queryResult) {
        // Record that these categories now have nodes (true = first time seen).
        boolean srcFlag = cIdHelper.add(srcCid);
        boolean tarFlag = cIdHelper.add(tarCid);
        Set<Object> srcValueHelper = new HashSet<>();
        Set<Object> tarValueHelper = new HashSet<>();

        CategoryVO srcCategory = cIdMap.get(srcCid);
        CategoryVO tarCategory = cIdMap.get(tarCid);

        for (List<Object> row : queryResult) {
            String srcNid = createNode(srcFlag, srcValueHelper, srcCid, srcCategory, row.get(0));
            String tarNid = createNode(tarFlag, tarValueHelper, tarCid, tarCategory, row.get(1));
            if (srcNid == null || tarNid == null) {
                logger.error("createNode error, Nid = {}", srcNid + ", " + tarNid);
                return false;
            }
            createLink(srcNid, tarNid, eId);
        }
        return true;
    }

    /**
     * Like {@link #buildPairFromQueryResult} but for a self edge. A row and its mirror
     * (a,b)/(b,a) produce a single link. Self-loops (a,a) are kept.
     *
     * @return false if a node id could not be resolved
     */
    public boolean buildSelfEdgeFromQueryResult(String cid, String eId, List<List<Object>> queryResult) {
        boolean flag = cIdHelper.add(cid);
        Set<Object> valueHelper = new HashSet<>();
        // Unordered endpoint pairs already linked, so mirrored rows are not linked twice.
        Set<Set<String>> edgeHelper = new HashSet<>();

        CategoryVO category = cIdMap.get(cid);

        for (List<Object> row : queryResult) {
            String srcNid = createNode(flag, valueHelper, cid, category, row.get(0));
            String tarNid = createNode(flag, valueHelper, cid, category, row.get(1));
            if (srcNid == null || tarNid == null) {
                logger.error("createNode error, Nid = {}", srcNid + ", " + tarNid);
                return false;
            }
            Set<String> endpoints = new HashSet<>();
            endpoints.add(srcNid);
            endpoints.add(tarNid);
            if (edgeHelper.add(endpoints)) {
                createLink(srcNid, tarNid, eId);
            }
        }
        return true;
    }

    /** Creates one node per distinct key value of a category that takes part in no edge. */
    public void buildIsolatedCategory(String cid) {
        Set<Object> valueHelper = new HashSet<>();
        CategoryVO category = cIdMap.get(cid);
        String queryColumnName = category.getKeyAttr().getColumn();
        String tableName = ToolUtil.alignTableName(category.tableName(), 0L);
        String sql = String.format(QUERY_COLUMNS_FROM_ONE_TABLE_SQL, queryColumnName, tableName, queryColumnName);
        List<List<Object>> queryResult = graphService.queryColumnsFromTable(sql);
        for (List<Object> row : queryResult) {
            Object keyAttrValue = row.get(0);
            String nid = createNode(true, valueHelper, cid, category, keyAttrValue);
            if (nid == null) {
                logger.error("createIsolatedNode error, cid = {},  keyAttrValue = {}", cid, keyAttrValue);
            }
        }
    }

    /**
     * Returns the id of the node of category {@code cId} whose key attribute equals
     * {@code keyAttrValue}, creating it (key attribute first, then the category's secondary
     * attributes) if it does not exist yet.
     *
     * <p>Creation is deliberately not gated on {@code flag}. A category already built by an
     * earlier edge may still meet key values that the earlier edge's query did not return, and
     * those values need nodes too.
     *
     * @param flag         kept for signature compatibility; no longer affects the result
     * @param valueHelper  caller-owned set; every key value passed in is recorded in it
     * @param cId          category id
     * @param category     category config for {@code cId}
     * @param keyAttrValue value of the category's key attribute
     * @return the node id (never {@code null})
     * @throws IllegalStateException if no attribute data was loaded for {@code keyAttrValue}
     */
    public String createNode(boolean flag, Set<Object> valueHelper, String cId, CategoryVO category, Object keyAttrValue) {
        valueHelper.add(keyAttrValue);
        Map<Object, String> keyIndex = cId2keyNidMap.get(cId);
        String existingNid = keyIndex.get(keyAttrValue);
        if (existingNid != null) {
            return existingNid;
        }
        GraphAttr attrData = attrDataHelper.get(cId).get(keyAttrValue);
        if (attrData == null) {
            throw new IllegalStateException("no attribute data for category " + cId + ", keyAttrValue " + keyAttrValue);
        }
        String nId = idGenerator("n");
        List<AttrVO> attrs = new ArrayList<>();
        // Key attribute must stay at index 0: createLink reads getAttrs().get(0) as the key value.
        attrs.add(AttrVO.builder().
                key(category.getKeyAttr().getName()).
                value(keyAttrValue).
                type(category.getKeyAttr().getType()).
                build());
        for (CategoryAttrVO categoryAttr : category.getAttrs()) {
            Object val = attrData.getAttr().get(categoryAttr.getName());
            markMultiValue(cId, categoryAttr.getName(), val);
            attrs.add(AttrVO.builder().
                    key(categoryAttr.getName()).
                    value(val).
                    type(categoryAttr.getType()).
                    build());
        }
        NodeVO node = new NodeVO();
        node.setId(nId);
        node.setOrderId((int) nIdCnt);
        node.setLabel(attrData.getLabelName());
        node.setCategoryId(cId);
        node.setAttrs(attrs);
        node.getStyle().put("fill", category.getStyle().getString("fill"));
        cId2nIdMap.get(cId).put(nId, node);
        keyIndex.put(keyAttrValue, nId);
        logger.debug("createNode: {}", node);
        return nId;
    }

    /**
     * Creates a link instance of edge {@code eId} between two existing nodes.
     *
     * <p>A missing edge style, or a style without {@code endArrow}, yields a directed link. The
     * weight comes from the attribute configured in {@code weightCfg.attrId}, falling back to 1.0.
     *
     * @return the new link id
     */
    public String createLink(String srcNid, String tarNid, String eId) {
        EdgeVO edge = eIdMap.get(eId);
        String lId = idGenerator("l");
        LinkVO link = new LinkVO();
        link.setId(lId);
        link.setOrderId((int) lIdCnt);
        link.setSource(srcNid);
        link.setTarget(tarNid);
        link.setEdgeId(eId);
        link.setLabel(edge.getLabel());
        // getBoolean returns a nullable Boolean; comparing avoids an unboxing NPE when absent.
        link.setDirected(edge.getStyle() == null || !Boolean.FALSE.equals(edge.getStyle().getBoolean("endArrow")));
        link.setWeight(resolveWeight(edge, srcNid, tarNid));
        link.setAttrs(buildLinkAttrs(edge, eId, srcNid, tarNid));
        if (edge.getStyle() != null) {
            link.getStyle().put("stroke", edge.getStyle().getString("stroke"));
        }
        eId2lIdMap.get(eId).put(lId, link);
        logger.debug("createLink: {}", link);
        return lId;
    }

    /**
     * Reads the link weight from the attribute named by the edge's {@code weightCfg.attrId}.
     * Returns 1.0 when no weight is configured or the value is missing or not numeric.
     */
    private double resolveWeight(EdgeVO edge, String srcNid, String tarNid) {
        JSONObject weightCfg = edge.getWeightCfg();
        if (weightCfg == null || weightCfg.isEmpty() || weightCfg.get("attrId") == null) {
            return 1.0;
        }
        Pair<String, CategoryAttrVO> owner = attrIdHelper.get(weightCfg.getString("attrId"));
        Object weight = readEndpointAttr(edge, owner.getKey(), owner.getValue().getName(), srcNid, tarNid);
        if (weight == null) {
            return 1.0;
        }
        try {
            return Double.parseDouble(weight.toString());
        } catch (NumberFormatException e) {
            return 1.0;
        }
    }

    /** Builds the link's attribute list from the edge's {@code attrIds}, flagging multi-valued ones. */
    private List<AttrVO> buildLinkAttrs(EdgeVO edge, String eId, String srcNid, String tarNid) {
        List<AttrVO> attrs = new ArrayList<>();
        if (edge.getAttrIds() == null) {
            return attrs;
        }
        for (String attrId : edge.getAttrIds()) {
            Pair<String, CategoryAttrVO> owner = attrIdHelper.get(attrId);
            CategoryAttrVO categoryAttr = owner.getValue();
            Object val = readEndpointAttr(edge, owner.getKey(), categoryAttr.getName(), srcNid, tarNid);
            markMultiValue(eId, categoryAttr.getName(), val);
            attrs.add(AttrVO.builder().
                    key(categoryAttr.getName()).
                    value(val).
                    type(categoryAttr.getType()).
                    build());
        }
        return attrs;
    }

    /**
     * Reads attribute {@code attrName} of the link endpoint belonging to category {@code cid}:
     * the source node if {@code cid} is the edge's source category, otherwise the target node.
     */
    private Object readEndpointAttr(EdgeVO edge, String cid, String attrName, String srcNid, String tarNid) {
        String nid = cid.equals(edge.getSource()) ? srcNid : tarNid;
        Object keyAttrVal = cId2nIdMap.get(cid).get(nid).getAttrs().get(0).getValue();
        return attrDataHelper.get(cid).get(keyAttrVal).getAttr().get(attrName);
    }

    /** Flags {@code attrName} of category/edge {@code ownerId} as multi-valued if {@code val} has more than one element. */
    private void markMultiValue(String ownerId, String attrName, Object val) {
        if (val instanceof Collection && ((Collection<?>) val).size() > 1) {
            multiValueHelper.get(ownerId).put(attrName, true);
        }
    }

    /**
     * Issues the next sequential id for a node ({@code "n"}) or link ({@code "l"}), formatted as
     * {@code prefix_graphId_seq}.
     *
     * @return the new id, or {@code "0"} for any other prefix
     */
    public String idGenerator(String prefix) {
        if ("n".equals(prefix)) {
            nIdCnt++;
            return String.join("_", "n", graphId.toString(), Long.toString(nIdCnt));
        }
        if ("l".equals(prefix)) {
            lIdCnt++;
            return String.join("_", "l", graphId.toString(), Long.toString(lIdCnt));
        }
        return "0";
    }

    /** Writes all built nodes and links into JanusGraph and returns the id map from the embed service. */
    public Map<String, String> write2Janus(Long graphId) {
        return janusGraphEmbedService.createAllOneTime(graphId, categories, edges, cIdMap, eIdMap, outputHelper.getNodes(), outputHelper.getLinks(), true);
    }

}
